By Nadia Metoui
Faculty of Technology, Policy, and Management (TPM)
Learning Objectives
Implement and Examine Intrinsic Explainable ML Models.
Implement and Examine Post-Hoc Explainability Models.
Discuss explainability in a multi-stakeholder socio-technical context.
Assignment Scenario
The healthcare sector is high-risk and highly regulated. Many stakeholders are involved in ensuring these regulations are respected and risks are kept to a minimum. Machine learning has many advantages and can be valuable for many health applications. It should, however, be transparent, trustworthy, and trusted.
In this assignment, you will train a classifier to predict if a patient is at risk of diabetes or not. This model will be used by physicians before selecting candidates for a new drug trial. This drug might have adverse effects on individuals with risks of developing diabetes. Therefore, these individuals should not be on the list of candidates.
After training the model you should ensure the model and its results are explainable to the several stakeholders (including the development team, the hospital management, the doctors and the candidates/patients).
Concretely you will practice different types of explainability we saw in class (during Lecture and the Lab 5); you will observe the explanations of each type of explainer; and discuss their utility (for different stakeholders) and their limitations.
Assignment Steps
This assignment is composed of four parts.
Each part contains coding tasks (To Code) and textual answers (To Answer).
*The total number of points (out of 100) will be normalized to calculate your grade out of 10 points.
Submission Instructions
Install and Load the libraries for the Lab.
!pip install lime
!pip install shap
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
from IPython.display import Markdown, display
import seaborn as sns
import sklearn.model_selection
import sklearn.metrics
import sklearn.datasets
import sklearn.ensemble
import sklearn.preprocessing
from sklearn.metrics import accuracy_score
import xgboost
from xgboost import plot_importance
import lime
import lime.lime_tabular
import shap
from shap.plots import _waterfall
np.random.seed(1)
For this assignment you will be using the Pima Indians Diabetes Database. We use the preprocessed version published on the Kaggle website (here).
Load the dataset
#Note: you may have to move the csv file to the appropriate folder or change the path "data/diabetes.csv" in the code below
#Load data from CSV file to a data frame
df_diabetes = pd.read_csv("data/diabetes.csv")
df_diabetes.shape
(768, 9)
df_diabetes.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 768 entries, 0 to 767
Data columns (total 9 columns):
 #   Column                    Non-Null Count  Dtype
---  ------                    --------------  -----
 0   Pregnancies               768 non-null    int64
 1   Glucose                   768 non-null    int64
 2   BloodPressure             768 non-null    int64
 3   SkinThickness             768 non-null    int64
 4   Insulin                   768 non-null    int64
 5   BMI                       768 non-null    float64
 6   DiabetesPedigreeFunction  768 non-null    float64
 7   Age                       768 non-null    int64
 8   Outcome                   768 non-null    int64
dtypes: float64(2), int64(7)
memory usage: 54.1 KB
Here we see that there are no null values. The dataset consists of eight columns of medical/personal attributes that might indicate whether a person is at risk of diabetes, plus a ninth column, Outcome, that holds the label. In total, we have data on 768 patients.
[Optional - Not Graded ]
We recommend you do some data exploration to get familiar with the attributes and their values.
This part is not graded, but if you provide great data visualization and exploration ideas you might get a bonus :)
df_diabetes.head()
|   | Pregnancies | Glucose | BloodPressure | SkinThickness | Insulin | BMI | DiabetesPedigreeFunction | Age | Outcome |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 6 | 148 | 72 | 35 | 0 | 33.6 | 0.627 | 50 | 1 |
| 1 | 1 | 85 | 66 | 29 | 0 | 26.6 | 0.351 | 31 | 0 |
| 2 | 8 | 183 | 64 | 0 | 0 | 23.3 | 0.672 | 32 | 1 |
| 3 | 1 | 89 | 66 | 23 | 94 | 28.1 | 0.167 | 21 | 0 |
| 4 | 0 | 137 | 40 | 35 | 168 | 43.1 | 2.288 | 33 | 1 |
rows = 3
columns = 3
fig, axes = plt.subplots(rows, columns, figsize=(8, 8))
fig.set_tight_layout(True)
i, j = 0, 0
for c in df_diabetes.columns:
    if c == 'Outcome':
        # The label is categorical, so show class counts instead of a boxplot
        sns.countplot(data=df_diabetes, x='Outcome', ax=axes[i][j])
    else:
        sns.boxplot(data=df_diabetes, y=c, x='Outcome', ax=axes[i][j])
    axes[i][j].set_title(c)
    j += 1
    if j % columns == 0:
        i += 1
        j = 0
In the next cell we delete the 44 records with incorrectly measured (zero-valued) features.
#Here we drop the 0 values that we have just observed in the dataset
print(df_diabetes.shape)
df_diabetes.drop(df_diabetes[(df_diabetes.Glucose <= 0) | (df_diabetes.BMI <= 0) | (df_diabetes.BloodPressure <= 0)].index, inplace=True)
df_diabetes.reset_index(inplace=True, drop=True)
print(df_diabetes.shape)
(768, 9)
(724, 9)
df_diabetes.head()
|   | Pregnancies | Glucose | BloodPressure | SkinThickness | Insulin | BMI | DiabetesPedigreeFunction | Age | Outcome |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 6 | 148 | 72 | 35 | 0 | 33.6 | 0.627 | 50 | 1 |
| 1 | 1 | 85 | 66 | 29 | 0 | 26.6 | 0.351 | 31 | 0 |
| 2 | 8 | 183 | 64 | 0 | 0 | 23.3 | 0.672 | 32 | 1 |
| 3 | 1 | 89 | 66 | 23 | 94 | 28.1 | 0.167 | 21 | 0 |
| 4 | 0 | 137 | 40 | 35 | 168 | 43.1 | 2.288 | 33 | 1 |
From the boxplots we observed that some features take the value 0 where that is not physiologically possible: there are zero values for Glucose, BMI and BloodPressure. Furthermore, people at high risk of diabetes generally show higher age, number of pregnancies, BMI, insulin, glucose and blood pressure. The dataset is also unbalanced: there are fewer people at risk of diabetes (around 270) than people without (around 500).
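The impossible-zero observation can be made concrete by counting zeros per column. A minimal pandas sketch on a hypothetical mini-frame (not the full dataset):

```python
import pandas as pd

# Hypothetical mini-frame mimicking the dataset's zero-value problem
df = pd.DataFrame({
    "Glucose": [148, 0, 183, 89],
    "BloodPressure": [72, 66, 0, 66],
    "BMI": [33.6, 26.6, 23.3, 0.0],
})

# Count zeros in columns where zero is physiologically impossible
zero_counts = (df[["Glucose", "BloodPressure", "BMI"]] == 0).sum()
print(zero_counts.to_dict())  # one invalid record per column in this toy frame
```

Applying the same expression to `df_diabetes` would reveal how many records the drop in the previous cell removes.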
""#We can use a heatmap to determine features that have highest correlation with SHARE_HIGH.
# heatmap of correlations
# Create plot
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
fig.set_tight_layout(True)
# Compute correlation matrix
corr = df_diabetes.corr()
# Create upper triangular matrix to mask the upper triangular part of the heatmap
corr_mask = np.triu(np.ones_like(corr, dtype=bool))
# Generate a custom diverging colormap (because it looks better)
corr_cmap = sns.diverging_palette(230, 20, as_cmap=True)
sns.heatmap(corr, mask = corr_mask, cmap=corr_cmap, annot=True,square = True, linewidths=.5, ax = axes[0][0])
axes[0][0].set_title("Correlation plot of dataset features")
sns.kdeplot(data = df_diabetes, x='BMI', y='BloodPressure', ax=axes[0][1], fill=True)
axes[0][1].set_title("Density plot of BMI against blood pressure")
df_low = df_diabetes[df_diabetes.Glucose <= df_diabetes.Glucose.mean()]
df_high = df_diabetes[df_diabetes.Glucose > df_diabetes.Glucose.mean()]
sns.countplot(data=df_low, x='Outcome', ax=axes[1][0])
sns.countplot(data=df_high, x='Outcome', ax=axes[1][1])
axes[1][0].set_title("Outcome of people with <= average glucose levels")
axes[1][1].set_title("Outcome of people with > average glucose levels")
plt.show()
From the plots we can see that Glucose and BMI are the most important features for indicating a high risk of diabetes (highest correlation with Outcome). Furthermore, a higher BMI generally correlates with higher blood pressure, and people with a lower-than-average glucose level have a much lower chance of diabetes than those with high glucose levels, which makes sense.
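This correlation-based ranking can also be computed directly rather than read off the heatmap. A minimal sketch on synthetic stand-in data (hypothetical values, not the actual diabetes frame):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Synthetic stand-in: Glucose carries a strong signal, BMI a weaker one,
# BloodPressure none (assumed effect sizes, for illustration only)
outcome = rng.integers(0, 2, size=300)
df = pd.DataFrame({
    "Glucose": outcome * 30 + rng.normal(100, 15, 300),
    "BMI": outcome * 4 + rng.normal(30, 5, 300),
    "BloodPressure": rng.normal(70, 10, 300),
    "Outcome": outcome,
})

# Rank features by absolute correlation with the label
ranking = df.corr()["Outcome"].drop("Outcome").abs().sort_values(ascending=False)
print(ranking.round(2).to_dict())
```

Running the same one-liner on `df_diabetes` gives the ranking discussed above.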
Note: the dataset we selected for this assignment has no categorical features. This means you will not need to encode categorical features or use any encoders, which makes this assignment simpler than the Lab. It also allows you to generate explanations from both LIME and SHAP using the same model. You do not need to retrain the model we provided.
#Get Outcome labels
labels = df_diabetes['Outcome']
#Get features
data = df_diabetes.drop('Outcome', axis=1)
# create a train/test split
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(data, labels, train_size=0.70, random_state=7)
print("Train shape: ", X_train.shape)
print("Test shape: ", X_test.shape)
Train shape: (506, 8) Test shape: (218, 8)
# Fit the model
gbtree_model = xgboost.XGBClassifier(learning_rate=0.01)
gbtree_model.fit(X_train, y_train)
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None,
colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1,
early_stopping_rounds=None, enable_categorical=False,
eval_metric=None, gamma=0, gpu_id=-1, grow_policy='depthwise',
importance_type=None, interaction_constraints='',
learning_rate=0.01, max_bin=256, max_cat_to_onehot=4,
max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1,
missing=nan, monotone_constraints='()', n_estimators=100,
n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0,
reg_alpha=0, reg_lambda=1, ...)
# Make predictions on test data
y_pred = gbtree_model.predict(X_test)
# Calculate accuracy against the real test outcomes
accuracy = sklearn.metrics.accuracy_score(y_test, y_pred)
print("Accuracy: %.2f%%" % (accuracy * 100.0))
Accuracy: 77.06%
To Code (5 points)
A. Create 3 different feature importance plots using the intrinsic explainability of your model.
Hint: Take a look at xgboost [documentation](https://xgboost.readthedocs.io/en/latest/python/python_api.html#).
# Feature importance plot 1: weight
gbtree_model.get_booster().feature_names = list(df_diabetes.columns)[:-1]
plt.rcParams["figure.figsize"] = (8, 8)
xgboost.plot_importance(gbtree_model, importance_type="weight")
plt.title("Feature importance using weight as measurement criterion")
plt.show()

# Feature importance plot 2: cover
xgboost.plot_importance(gbtree_model, importance_type="cover")
plt.title("Feature importance using cover as measurement criterion")
plt.show()

# Feature importance plot 3: gain
xgboost.plot_importance(gbtree_model, importance_type="gain")
plt.title("Feature importance using gain as measurement criterion")
plt.show()
To Answer (15 points)
B. What kind of explanations do these plots provide? Think about three or four criteria.
C. Who can use these explanations in our context (which stakeholder(s)) and why (for what purpose)?
D. Observe the three plots and evaluate the overall consistency of these intrinsic explanations (are they consistent? what is missing?)
B. XGBoost can define feature importance in three different ways: weight (how often a feature is used to split the data across all trees), gain (the average improvement in the loss brought by splits on the feature) and cover (the average number of samples affected by splits on the feature).
C. Medical experts can use these feature importances for more insight in why the model would predict this patient as high-risk / low risk. For example, the model can predict that a patient with a high glucose level has high risk of diabetes. Then by looking at the feature importances, the medical expert can inform and explain why the patient has a high risk of diabetes.
D.
The cover and gain importances match best, ranking Glucose, Age and BMI in the same order. Glucose is by far the most important feature for both, which is also confirmed by the correlation plot in the data-exploration part above. Interestingly, the gain measure best matches the ordering given by the correlations. The weight measure is quite different, ranking BMI and DiabetesPedigreeFunction as the two most important features, which rank lower on the other two measures.
Unfortunately, these feature importances also have limitations. They do not provide insight into the direction of the impact (+/-), and they only provide relative (not absolute) importances. They are also not very stable: retraining can yield different importances.
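One commonly used remedy for this instability is permutation importance, which measures how much the score drops when a feature's values are shuffled. A minimal sketch on synthetic data with a RandomForest stand-in (not the notebook's xgboost model):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

rng = np.random.default_rng(0)
# Synthetic data: only feature 0 carries any signal
X = rng.normal(size=(400, 3))
y = (X[:, 0] > 0).astype(int)

model = RandomForestClassifier(random_state=0).fit(X, y)
# Shuffle each feature 10 times and record the mean accuracy drop
result = permutation_importance(model, X, y, n_repeats=10, random_state=0)
print(result.importances_mean.round(2))  # the informative feature dominates
```

Unlike the split-based measures, this works for any fitted model and comes with a repeat-based spread (`importances_std`), though it still gives no sign of the effect.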
To Code
A. Implement a LIME explainer (10 points)
B. Use it to generate explanations (with visualization) on 4 data points (10 points)
Note: remember, our dataset does not have categorical features. You do not need to specify any when creating the LIME explainer. Feature names and class names are, however, needed. Take a closer look at the documentation (link above)
explainer = lime.lime_tabular.LimeTabularExplainer(X_train.values, feature_names=list(df_diabetes.columns)[:-1], class_names=['Low Risk','High Risk'])
#Here we have the first 5 test points with corresponding label
test_points = df_diabetes.iloc[list(X_test.index)]
test_points.head()
|   | Pregnancies | Glucose | BloodPressure | SkinThickness | Insulin | BMI | DiabetesPedigreeFunction | Age | Outcome |
|---|---|---|---|---|---|---|---|---|---|
| 321 | 1 | 130 | 70 | 13 | 105 | 25.9 | 0.472 | 22 | 0 |
| 618 | 11 | 127 | 106 | 0 | 0 | 39.0 | 0.190 | 51 | 0 |
| 376 | 4 | 95 | 64 | 0 | 0 | 32.0 | 0.161 | 31 | 1 |
| 94 | 0 | 125 | 96 | 0 | 0 | 22.5 | 0.262 | 21 | 0 |
| 663 | 2 | 127 | 46 | 21 | 335 | 34.4 | 0.176 | 22 | 0 |
exp = explainer.explain_instance(X_test.values[0], gbtree_model.predict_proba, num_features=5)
exp.show_in_notebook(show_table=True, show_all=True)
exp = explainer.explain_instance(X_test.values[1], gbtree_model.predict_proba, num_features=5)
exp.show_in_notebook(show_table=True, show_all=True)
exp = explainer.explain_instance(X_test.values[2], gbtree_model.predict_proba, num_features=5)
exp.show_in_notebook(show_table=True, show_all=True)
exp = explainer.explain_instance(X_test.values[7], gbtree_model.predict_proba, num_features=5)
exp.show_in_notebook(show_table=True, show_all=True)
To Answer
C. What kind of explanations do these plots provide? Think about three or four criteria. (5 points)
D. Briefly describe how we can read each of the 4 explanations for the points you selected. (10 points)
E. Who can use these explanations in our context (which stakeholder(s)) and why (for what purpose)? (5 points)
C. LIME stands for Local Interpretable Model-agnostic Explanations. The local aspect means it explains individual predictions of a machine learning model. It shows which medical features contribute to predicting that a patient is at high or low risk, and how much each feature contributes to that decision.
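The local-surrogate idea behind LIME can be sketched in a few lines: perturb the instance, weight the perturbations by proximity, and fit a weighted linear model to the black box's outputs. This is a conceptual sketch with a hypothetical two-feature black box, not the lime library itself:

```python
import numpy as np
from sklearn.linear_model import Ridge

# Hypothetical black box: probability driven mostly by feature 0
# (a stand-in for something like Glucose)
def black_box_proba(X):
    return 1 / (1 + np.exp(-(2.0 * X[:, 0] + 0.3 * X[:, 1])))

rng = np.random.default_rng(0)
x0 = np.array([1.0, -0.5])                        # instance to explain
Z = x0 + rng.normal(scale=0.5, size=(500, 2))     # perturbed neighbours
weights = np.exp(-np.sum((Z - x0) ** 2, axis=1))  # proximity kernel

# Weighted linear surrogate: its coefficients are the local explanation
surrogate = Ridge(alpha=1.0).fit(Z, black_box_proba(Z), sample_weight=weights)
print(surrogate.coef_.round(2))
```

The bar lengths LIME displays correspond to such surrogate coefficients; here the dominant feature of the black box receives the larger coefficient.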
D. Let's walk through all our 4 datapoints.
Vis 0:
The true outcome is Low-Risk. On the left we can see that the model predicts with 0.78 probability that the patient has a low risk of diabetes. In the middle of the plot we can observe that this is mainly due to the patient's BMI <= 27.73 and Age <= 24.00. This seems to be a young patient with a healthy weight, who is therefore classified as low risk.
Vis 1:
The true outcome of this person is Low-Risk. On the left we can see that the model is quite unsure, predicting with 0.57 probability that the patient has a low risk of diabetes. In the middle of the plot we can observe the patient's BMI > 36.80 and Age > 41.00. The higher age and BMI push towards classifying this patient as high-risk, but the blood pressure and glucose levels seem normal, so the model eventually classifies the patient as low-risk, which is indeed correct.
Vis 2:
The true outcome of this person is High-Risk. On the left we can see that the model is quite unsure, predicting with 0.67 probability that the patient has a low risk of diabetes. The person is somewhat overweight, an indicator of diabetes, but the glucose level is low. In the middle of the plot we can observe that the model attributes Glucose <= 99.0 to the low-risk class. Because this is the most important feature here, the model eventually predicts low-risk, which is wrong: the patient is actually high risk.
Vis 3:
The true outcome of this person is High-Risk. On the left we can see that the model is quite sure, predicting with 0.72 probability that the patient has a high risk of diabetes. The person is heavily overweight. In the middle of the plot we can observe that the model indeed uses Glucose and BMI as the features most important to the high-risk prediction. Even though this is a young person, the model correctly identifies them as being at high risk of developing diabetes.
E. LIME has the advantage that we can analyse each person's medical data locally and see which features were important in the model's classification. Medical experts and doctors can use this to discuss results with patients, pinpoint exactly which medical feature is the reason for the high/low risk, and thereby motivate why this person would or would not be selected for the clinical drug trial.
To Code
A. Implement a SHAP explainer (5 points)
B. Use it to generate explanations: (15 points)
shap.initjs()
#Create explainer and run on test data
explainer = shap.TreeExplainer(gbtree_model)
shap_values = explainer.shap_values(X_test)
ntree_limit is deprecated, use `iteration_range` or model slicing instead.
figure = plt.figure()
shap.summary_plot(shap_values, X_test, plot_type='bar')
This plot shows the global list of important features, from most significant to least significant. Glucose has the most predictive power according to the model, and Age has the second most.
shap.summary_plot(shap_values, X_test)
The above plot shows the global list of most important features from top to bottom. Each dot represents the feature value for a single data instance: a blue dot indicates a low feature value and a red dot a high one. We can observe that a high Glucose value pushes the prediction towards a positive (diabetes) outcome; the same generally holds for Age, BMI and DiabetesPedigreeFunction.
The explanations below show how the features each contribute to pushing the model output from the base value (the average model output over the training dataset we passed) to the final prediction. Features pushing the prediction higher are shown in red; those pushing it lower are in blue.
_waterfall.waterfall_legacy(explainer.expected_value, shap_values[0], X_test.iloc[0,:])
The true outcome is Low-Risk. Because of the low BMI and low age of this patient, the prediction is pushed to the left, and so this patient is predicted as low risk.
_waterfall.waterfall_legacy(explainer.expected_value, shap_values[1], X_test.iloc[1,:])
The true outcome of this person is Low-Risk. The model is a bit unsure here: the high BMI and Age push the prediction to the right (high-risk), but other factors push it back to the left. Eventually the model outputs -0.298, which is still below 0, so the patient is classified as low risk.
_waterfall.waterfall_legacy(explainer.expected_value, shap_values[2], X_test.iloc[2,:])
The true outcome of this person is High-Risk. The model is a bit unsure here, mainly the low glucose is pushing the prediction far to the left, while the person is a bit older and overweight which pushes it to the right.
_waterfall.waterfall_legacy(explainer.expected_value, shap_values[7], X_test.iloc[7,:])
The true outcome of this person is High-Risk. The model is quite certain here, mainly because of the elevated glucose level, which pushes the prediction far to the right and results in a correct high-risk prediction.
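The "push from base value" arithmetic in these waterfall plots can be reproduced exactly on a toy model. This sketch computes two-feature Shapley values by hand (illustrative only, not the shap library); note how the base value plus the contributions reconstructs the prediction:

```python
# Toy linear model to explain (hypothetical coefficients)
def f(x1, x2):
    return 3.0 * x1 + 1.0 * x2

base = (0.0, 0.0)   # background/reference point (the "base value")
x = (2.0, 1.0)      # instance to explain

# Coalition values: absent features are set to the background value
v_none = f(*base)
v_1 = f(x[0], base[1])
v_2 = f(base[0], x[1])
v_12 = f(*x)

# Shapley values for two features: average marginal contribution over orderings
phi1 = 0.5 * ((v_1 - v_none) + (v_12 - v_2))
phi2 = 0.5 * ((v_2 - v_none) + (v_12 - v_1))
print(phi1, phi2, v_none + phi1 + phi2, f(*x))  # → 6.0 1.0 7.0 7.0
```

This additivity property (base value + contributions = model output) is exactly what the waterfall plots above visualise, with red and blue bars for positive and negative contributions.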
#Get test points for persons with > 2 pregnancies
test_points = df_diabetes.iloc[list(X_test.index)]
pregnant_df = test_points[test_points.Pregnancies > 2]
pregnant_y = pregnant_df['Outcome']
pregnant_x = pregnant_df.drop('Outcome', axis=1)
print(test_points.shape)
print(pregnant_x.shape)
(218, 9) (116, 8)
shap_p = explainer.shap_values(pregnant_x)
shap.force_plot(explainer.expected_value, shap_p, pregnant_x, show=True)
ntree_limit is deprecated, use `iteration_range` or model slicing instead.
This force plot shows a combined plot for the people who have had more than two pregnancies. It conveys the same information as the previous waterfall plots, but in an interactive JavaScript environment where we can see which main features contributed to each classification.
To Answer
C. Briefly describe what information we can get from each of the plots (B1, B2, and B3). (10 points)
D. Who can use each type of explanations and for what purposes? (10 points)
C. The descriptions of the plots are located below each plot.
D. It is very important for physicians to check that the model predicts low-risk patients correctly, since those patients become eligible for the drug trial. If a person actually has (or will develop) diabetes while the model predicts low risk, this can cause major harm, since the drug can have adverse effects on people with diabetes. So we are especially interested in false negatives (the recall metric). Using the per-person explanations above, a physician can discuss results with a patient, pinpoint exactly which medical feature drives the high/low risk prediction, and motivate why this person would or would not be selected for the clinical drug trial. The development team and hospital management might be more interested in which features globally have the most impact on the classification outcome; for that, the plots in B1 are most relevant.
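The focus on false negatives can be quantified with standard metrics. A minimal sketch on hypothetical predictions (not the notebook's actual test results), where 1 means high risk and should lead to exclusion from the trial:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, recall_score

# Hypothetical labels and predictions for eight patients
y_true = np.array([1, 1, 1, 0, 0, 0, 1, 0])
y_pred = np.array([1, 0, 1, 0, 0, 1, 1, 0])

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
# A false negative here is a high-risk patient the model clears for the trial
print("false negatives:", fn)
print("recall:", recall_score(y_true, y_pred))
```

Reporting recall alongside accuracy would make the safety-critical error mode visible to all stakeholders, not just the development team.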